/****************************************************************************
*
*    Copyright (c) 2020 Vivante Corporation
*
*    Permission is hereby granted, free of charge, to any person obtaining a
*    copy of this software and associated documentation files (the "Software"),
*    to deal in the Software without restriction, including without limitation
*    the rights to use, copy, modify, merge, publish, distribute, sublicense,
*    and/or sell copies of the Software, and to permit persons to whom the
*    Software is furnished to do so, subject to the following conditions:
*
*    The above copyright notice and this permission notice shall be included in
*    all copies or substantial portions of the Software.
*
*    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
*    IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
*    FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
*    AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
*    LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
*    FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
*    DEALINGS IN THE SOFTWARE.
*
*****************************************************************************/
#ifndef _VSI_NN_OP_LSTMUNIT_OVXLIB_H
#define _VSI_NN_OP_LSTMUNIT_OVXLIB_H

#include "vsi_nn_tensor.h"
#include "vsi_nn_types.h"
#include "vsi_nn_op_lstmunit.h"

#ifdef __cplusplus
extern "C" {
#endif

/*
 * Gate count constant (value 4). The name suggests the four gates spelled
 * I/F/C/O, matching the _I/_F/_C/_O suffixes used by the enumerators in this
 * header. It is not referenced elsewhere in this header; users live in other
 * files.
 */
#define LSTMUNIT_IFCO_GATE_COUNT 4

/*
 * Indices of the LSTM unit's inputs.
 * Values are explicit and contiguous from 0; LSTMUNIT_INPUT_CNT follows the
 * last index and therefore equals the number of input indices (33).
 * The group labels below only restate the enumerator name prefixes.
 */
enum
{
    /* INPUT / H_STATE / C_STATE */
    LSTMUNIT_INPUT_INPUT = 0,
    LSTMUNIT_INPUT_H_STATE = 1,
    LSTMUNIT_INPUT_C_STATE = 2,

    /* WEIGHT_I2{I,F,C,O} */
    LSTMUNIT_INPUT_WEIGHT_I2I = 3,
    LSTMUNIT_INPUT_WEIGHT_I2F = 4,
    LSTMUNIT_INPUT_WEIGHT_I2C = 5,
    LSTMUNIT_INPUT_WEIGHT_I2O = 6,

    /* WEIGHT_R2{I,F,C,O} */
    LSTMUNIT_INPUT_WEIGHT_R2I = 7,
    LSTMUNIT_INPUT_WEIGHT_R2F = 8,
    LSTMUNIT_INPUT_WEIGHT_R2C = 9,
    LSTMUNIT_INPUT_WEIGHT_R2O = 10,

    /* WEIGHT_C2{I,F,O} -- note there is no C2C entry */
    LSTMUNIT_INPUT_WEIGHT_C2I = 11,
    LSTMUNIT_INPUT_WEIGHT_C2F = 12,
    LSTMUNIT_INPUT_WEIGHT_C2O = 13,

    /* BIAS_{I,F,C,O} */
    LSTMUNIT_INPUT_BIAS_I = 14,
    LSTMUNIT_INPUT_BIAS_F = 15,
    LSTMUNIT_INPUT_BIAS_C = 16,
    LSTMUNIT_INPUT_BIAS_O = 17,

    /* WEIGHT_PROJ / BIAS_PROJ */
    LSTMUNIT_INPUT_WEIGHT_PROJ = 18,
    LSTMUNIT_INPUT_BIAS_PROJ = 19,

    /* LAYERNORM_{I,F,C,O} */
    LSTMUNIT_INPUT_LAYERNORM_I = 20,
    LSTMUNIT_INPUT_LAYERNORM_F = 21,
    LSTMUNIT_INPUT_LAYERNORM_C = 22,
    LSTMUNIT_INPUT_LAYERNORM_O = 23,

    /* AUX_INPUT and AUX_WEIGHT_I2{I,F,C,O} */
    LSTMUNIT_INPUT_AUX_INPUT = 24,
    LSTMUNIT_INPUT_AUX_WEIGHT_I2I = 25,
    LSTMUNIT_INPUT_AUX_WEIGHT_I2F = 26,
    LSTMUNIT_INPUT_AUX_WEIGHT_I2C = 27,
    LSTMUNIT_INPUT_AUX_WEIGHT_I2O = 28,

    /* BIAS_R2{I,F,C,O} */
    LSTMUNIT_INPUT_BIAS_R2I = 29,
    LSTMUNIT_INPUT_BIAS_R2F = 30,
    LSTMUNIT_INPUT_BIAS_R2C = 31,
    LSTMUNIT_INPUT_BIAS_R2O = 32,

    /* one past the last input index */
    LSTMUNIT_INPUT_CNT
};

/*
 * Indices of the LSTM unit's outputs.
 * Numbered independently of the inputs, again starting at 0;
 * LSTMUNIT_OUTPUT_CNT equals the number of output indices (4).
 */
enum
{
    LSTMUNIT_OUTPUT_OUTPUT = 0,
    LSTMUNIT_OUTPUT_H_STATE = 1,
    LSTMUNIT_OUTPUT_C_STATE = 2,
    LSTMUNIT_OUTPUT_SCRATCH = 3,

    /* one past the last output index */
    LSTMUNIT_OUTPUT_CNT
};

/*
 * 32-bit signed integer type declared next to the LSTMUNIT_NODE_* constants;
 * presumably meant to carry one of those values (its users are outside this
 * header).
 * NOTE(review): unlike the other vsi_nn_lstmunit_* identifiers in this file,
 * this name has no underscore between "vsi_nn" and "lstmunit". It is kept
 * as-is because it is part of the public interface.
 */
typedef int32_t vsi_nnlstmunit_ovxlib_internal_node_index_t;

/*
 * Internal node indices.
 * LSTMUNIT_NODE_INVALID is -1, real nodes count up from 0 in declaration
 * order, and LSTMUNIT_NODE_CNT is the number of real nodes.
 * Inserting or removing an entry renumbers everything after it.
 */
enum
{
    LSTMUNIT_NODE_INVALID = -1,

    /* new internal node definitions go below */

    /* input side: reshape / transpose, then FC_I2{I,F,C,O} */
    LSTMUNIT_NODE_RESHAPE_INPUT_FC_INPUT = 0,
    LSTMUNIT_NODE_TRANS_INPUT_FC_INPUT,
    LSTMUNIT_NODE_FC_I2I,
    LSTMUNIT_NODE_FC_I2F,
    LSTMUNIT_NODE_FC_I2C,
    LSTMUNIT_NODE_FC_I2O,

    /* NOTE(review): purpose not visible in this header; it still occupies an
     * index, so it cannot be dropped without renumbering later entries. */
    LSTMUNIT_NODE_TEST_NODE,

    /* recurrent side: reshape / transpose, then FC_R2{I,F,C,O} */
    LSTMUNIT_NODE_RESHAPE_RECURRENT_FC_INPUT,
    LSTMUNIT_NODE_TRANS_RECURRENT_FC_INPUT,

    LSTMUNIT_NODE_FC_R2I,
    LSTMUNIT_NODE_FC_R2F,
    LSTMUNIT_NODE_FC_R2C,
    LSTMUNIT_NODE_FC_R2O,

    /* NN_TRANSPOSE_* / NN_RESHAPE_* for the I2x entries */
    LSTMUNIT_NODE_NN_TRANSPOSE_I2I,
    LSTMUNIT_NODE_NN_TRANSPOSE_I2F,
    LSTMUNIT_NODE_NN_TRANSPOSE_I2C,
    LSTMUNIT_NODE_NN_TRANSPOSE_I2O,
    LSTMUNIT_NODE_NN_RESHAPE_I2I,
    LSTMUNIT_NODE_NN_RESHAPE_I2F,
    LSTMUNIT_NODE_NN_RESHAPE_I2C,
    LSTMUNIT_NODE_NN_RESHAPE_I2O,

    /* NN_TRANSPOSE_* / NN_RESHAPE_* for the R2x entries */
    LSTMUNIT_NODE_NN_TRANSPOSE_R2I,
    LSTMUNIT_NODE_NN_TRANSPOSE_R2F,
    LSTMUNIT_NODE_NN_TRANSPOSE_R2C,
    LSTMUNIT_NODE_NN_TRANSPOSE_R2O,
    LSTMUNIT_NODE_NN_RESHAPE_R2I,
    LSTMUNIT_NODE_NN_RESHAPE_R2F,
    LSTMUNIT_NODE_NN_RESHAPE_R2C,
    LSTMUNIT_NODE_NN_RESHAPE_R2O,

    /* concat of FC outputs, layer norm and its split */
    LSTMUNIT_NODE_INPUT_FC_OUTPUTS_CONCAT,
    LSTMUNIT_NODE_RECURRENT_FC_OUTPUTS_CONCAT,
    LSTMUNIT_NODE_LAYER_NORM,
    LSTMUNIT_NODE_LAYER_NORM_SPLIT,

    /* LAYER_NORM_{I,F,C,O} */
    LSTMUNIT_NODE_LAYER_NORM_I,
    LSTMUNIT_NODE_LAYER_NORM_F,
    LSTMUNIT_NODE_LAYER_NORM_C,
    LSTMUNIT_NODE_LAYER_NORM_O,

    /* activations */
    LSTMUNIT_NODE_ACTIVATIONS,

    /* projection */
    LSTMUNIT_NODE_RESHAPE_PROJECTION_FC_INPUT,
    LSTMUNIT_NODE_FC_PROJ,
    LSTMUNIT_NODE_ADD_PROJ,
    LSTMUNIT_NODE_RESHAPE_FC_PROJ,

    /* number of node indices (excluding INVALID) */
    LSTMUNIT_NODE_CNT
};

/*
 * Internal tensor indices.
 * Values count up from 0 in declaration order; LSTMUNIT_TENSOR_CNT is the
 * number of distinct indices (58). Inserting or removing a non-alias entry
 * renumbers everything after it.
 *
 * Two historical names are misspelled ("RESHAPRE"). They are kept for source
 * compatibility, and a correctly spelled alias with the same value is
 * declared right after each of them. An alias with an explicit value does not
 * advance the numbering of the entries that follow.
 */
enum
{
    /* input side: reshape / transpose */
    LSTMUNIT_TENSOR_RESHAPRE_INPUT_FC_INPUT,
    LSTMUNIT_TENSOR_RESHAPE_INPUT_FC_INPUT =
        LSTMUNIT_TENSOR_RESHAPRE_INPUT_FC_INPUT, /* preferred spelling */
    LSTMUNIT_TENSOR_TRANS_INPUT_FC_INPUT,

    /* ZERO_BIAS_I2{I,F,C,O} */
    LSTMUNIT_TENSOR_ZERO_BIAS_I2I,
    LSTMUNIT_TENSOR_ZERO_BIAS_I2F,
    LSTMUNIT_TENSOR_ZERO_BIAS_I2C,
    LSTMUNIT_TENSOR_ZERO_BIAS_I2O,

    /* recurrent side: reshape / transpose */
    LSTMUNIT_TENSOR_RESHAPRE_RECURRENT_FC_INPUT,
    LSTMUNIT_TENSOR_RESHAPE_RECURRENT_FC_INPUT =
        LSTMUNIT_TENSOR_RESHAPRE_RECURRENT_FC_INPUT, /* preferred spelling */
    LSTMUNIT_TENSOR_TRANS_RECURRENT_FC_INPUT,

    /* ZERO_BIAS_R2{I,F,C,O} */
    LSTMUNIT_TENSOR_ZERO_BIAS_R2I,
    LSTMUNIT_TENSOR_ZERO_BIAS_R2F,
    LSTMUNIT_TENSOR_ZERO_BIAS_R2C,
    LSTMUNIT_TENSOR_ZERO_BIAS_R2O,

    LSTMUNIT_TENSOR_CONCATED_BIAS,
    LSTMUNIT_TENSOR_CONCATED_LN_W,

    /* OUTPUT_{I2,R2}{I,F,C,O} */
    LSTMUNIT_TENSOR_OUTPUT_I2I,
    LSTMUNIT_TENSOR_OUTPUT_I2F,
    LSTMUNIT_TENSOR_OUTPUT_I2C,
    LSTMUNIT_TENSOR_OUTPUT_I2O,
    LSTMUNIT_TENSOR_OUTPUT_R2I,
    LSTMUNIT_TENSOR_OUTPUT_R2F,
    LSTMUNIT_TENSOR_OUTPUT_R2C,
    LSTMUNIT_TENSOR_OUTPUT_R2O,

    /* OUTPUT_NN_{I2,R2}{I,F,C,O} */
    LSTMUNIT_TENSOR_OUTPUT_NN_I2I,
    LSTMUNIT_TENSOR_OUTPUT_NN_I2F,
    LSTMUNIT_TENSOR_OUTPUT_NN_I2C,
    LSTMUNIT_TENSOR_OUTPUT_NN_I2O,
    LSTMUNIT_TENSOR_OUTPUT_NN_R2I,
    LSTMUNIT_TENSOR_OUTPUT_NN_R2F,
    LSTMUNIT_TENSOR_OUTPUT_NN_R2C,
    LSTMUNIT_TENSOR_OUTPUT_NN_R2O,

    /* OUTPUT_NN_TRANS_{I2,R2}{I,F,C,O} */
    LSTMUNIT_TENSOR_OUTPUT_NN_TRANS_I2I,
    LSTMUNIT_TENSOR_OUTPUT_NN_TRANS_I2F,
    LSTMUNIT_TENSOR_OUTPUT_NN_TRANS_I2C,
    LSTMUNIT_TENSOR_OUTPUT_NN_TRANS_I2O,
    LSTMUNIT_TENSOR_OUTPUT_NN_TRANS_R2I,
    LSTMUNIT_TENSOR_OUTPUT_NN_TRANS_R2F,
    LSTMUNIT_TENSOR_OUTPUT_NN_TRANS_R2C,
    LSTMUNIT_TENSOR_OUTPUT_NN_TRANS_R2O,

    /* RESHAPED_WEIGHT_{I2,R2}{I,F,C,O} */
    LSTMUNIT_TENSOR_RESHAPED_WEIGHT_I2I,
    LSTMUNIT_TENSOR_RESHAPED_WEIGHT_I2F,
    LSTMUNIT_TENSOR_RESHAPED_WEIGHT_I2C,
    LSTMUNIT_TENSOR_RESHAPED_WEIGHT_I2O,
    LSTMUNIT_TENSOR_RESHAPED_WEIGHT_R2I,
    LSTMUNIT_TENSOR_RESHAPED_WEIGHT_R2F,
    LSTMUNIT_TENSOR_RESHAPED_WEIGHT_R2C,
    LSTMUNIT_TENSOR_RESHAPED_WEIGHT_R2O,

    LSTMUNIT_TENSOR_INPUT_FC_OUTPUTS,
    LSTMUNIT_TENSOR_RECURRENT_FC_OUTPUTS,
    LSTMUNIT_TENSOR_LAYER_NORM_OUTPUT,

    /* LAYER_NORM_OUTPUT_{I,F,C,O} */
    LSTMUNIT_TENSOR_LAYER_NORM_OUTPUT_I,
    LSTMUNIT_TENSOR_LAYER_NORM_OUTPUT_F,
    LSTMUNIT_TENSOR_LAYER_NORM_OUTPUT_C,
    LSTMUNIT_TENSOR_LAYER_NORM_OUTPUT_O,

    LSTMUNIT_TENSOR_ACTIVATION_OUTPUT,
    LSTMUNIT_TENSOR_RESHAPE_PROJECTION_FC_INPUT, /* reshape projection input */
    LSTMUNIT_TENSOR_ZERO_BIAS_PROJECTION,
    LSTMUNIT_TENSOR_PROJECTION_FC_NN_OUTPUT,
    LSTMUNIT_TENSOR_PROJECTION_FC_OUTPUT,

    /* number of distinct tensor indices */
    LSTMUNIT_TENSOR_CNT
};

/*
 * Indices into vsi_nn_lstmunit_ovxlib_param::internal_dtype, whose length is
 * LSTMUNIT_QUANTIZE_PARAM_COUNT (8): four I2x entries followed by four R2x
 * entries.
 */
enum
{
    LSTMUNIT_QUANTIZE_PARAM_I2I = 0,
    LSTMUNIT_QUANTIZE_PARAM_I2F = 1,
    LSTMUNIT_QUANTIZE_PARAM_I2C = 2,
    LSTMUNIT_QUANTIZE_PARAM_I2O = 3,

    LSTMUNIT_QUANTIZE_PARAM_R2I = 4,
    LSTMUNIT_QUANTIZE_PARAM_R2F = 5,
    LSTMUNIT_QUANTIZE_PARAM_R2C = 6,
    LSTMUNIT_QUANTIZE_PARAM_R2O = 7,

    /* number of entries above */
    LSTMUNIT_QUANTIZE_PARAM_COUNT
};

/*
 * Indices for the auxiliary (AUX_I2x) quantize parameters;
 * LSTMUNIT_QUANTIZE_PARAM_AUX_COUNT is 4.
 * NOTE(review): presumably these index the array behind
 * vsi_nn_lstmunit_ovxlib_param::internal_dtype_aux -- confirm against the
 * implementation, which is not part of this header.
 */
enum
{
    LSTMUNIT_QUANTIZE_PARAM_AUX_I2I = 0,
    LSTMUNIT_QUANTIZE_PARAM_AUX_I2F = 1,
    LSTMUNIT_QUANTIZE_PARAM_AUX_I2C = 2,
    LSTMUNIT_QUANTIZE_PARAM_AUX_I2O = 3,

    /* number of entries above */
    LSTMUNIT_QUANTIZE_PARAM_AUX_COUNT
};

/*
 * Local data of the LSTM unit op: seven vsi_bool flags.
 * It is reached through the `local` pointer of vsi_nn_lstmunit_ovxlib_param.
 * The code that sets and reads these flags is not part of this header, so
 * the per-field notes below are read off the field names only -- TODO confirm
 * against the implementation.
 */
typedef struct _vsi_nn_lstmunit_ovxlib_lcl_data_t
{
    vsi_bool use_cifg;            /* presumably: CIFG variant in use */
    vsi_bool use_layer_norm;      /* presumably: layer normalization in use */
    vsi_bool use_projection;      /* presumably: projection in use */
    vsi_bool use_projection_bias; /* presumably: projection bias in use */
    vsi_bool use_hybrid;          /* presumably: hybrid mode in use */
    vsi_bool multi_batch;         /* presumably: more than one batch */
    vsi_bool use_peephole;        /* presumably: peephole weights in use */
} vsi_nn_lstmunit_ovxlib_lcl_data_t;

/*
 * Parameter block of the LSTM unit (ovxlib) op.
 *
 * The leading anonymous union overlays the `local` pointer with a struct of
 * six vsi_bool pads, so the union is as large as the bigger of the two
 * members. The field order and types of this struct are layout-sensitive;
 * do not reorder or insert fields.
 */
typedef struct _vsi_nn_lstmunit_ovxlib_param
{
    union
    {
        /* Pointer to the op's local flag block. */
        vsi_nn_lstmunit_ovxlib_lcl_data_t *local;
        /*
         * Padding kept for ABI compatibility. The pad names carry no
         * meaning; presumably these slots once held inline flags --
         * TODO confirm. Note the local flag block has seven flags while
         * there are six pads here.
         */
        struct { /* for ABI compatible */
            vsi_bool pad0;
            vsi_bool pad1;
            vsi_bool pad2;
            vsi_bool pad3;
            vsi_bool pad4;
            vsi_bool pad5;
        };
    };

    /* Scalar settings. Units, valid ranges and the meaning of special values
     * are not established by this header -- see the implementation. */
    float cell_clip;
    float proj_clip;
    vsi_nn_activation_e activation;
    float forget_bias;
    /* One dtype per LSTMUNIT_QUANTIZE_PARAM_* index (I2x then R2x). */
    vsi_nn_dtype_t internal_dtype[LSTMUNIT_QUANTIZE_PARAM_COUNT];
    vsi_nn_activation_e recurrent_activation;
    /* Pointer to auxiliary dtype(s). Ownership and element count are not
     * visible here; presumably LSTMUNIT_QUANTIZE_PARAM_AUX_COUNT entries --
     * TODO confirm. */
    vsi_nn_dtype_t *internal_dtype_aux;
} vsi_nn_lstmunit_ovxlib_param;

#ifdef __cplusplus
}
#endif

#endif
